#!/usr/bin/env python3
"""
Extract inner-most spans and their bounding boxes, and the MathML output,
from rendered LaTeX equations using Playwright and KaTeX.
Caching is maintained via a SHA1-based hash stored as a JSON file.

Requirements:
    pip install playwright
    python -m playwright install chromium

    Place katex.min.css and katex.min.js in the same directory as this script
"""

import hashlib
import json
import os
import pathlib
import re
import shutil
import threading
import unittest
from dataclasses import dataclass
from typing import List

from playwright.sync_api import Error as PlaywrightError
from playwright.sync_api import sync_playwright

# Thread-local storage for the Playwright and browser instances. Each thread lazily
# starts its own pair (see init_browser), so Playwright objects are never shared
# between threads.
_thread_local = threading.local()

# Global cache lock: serializes cache-file reads, cache-file writes and error-marker
# creation among the threads of this process (it does not coordinate separate processes).
cache_lock = threading.Lock()


@dataclass
class BoundingBox:
    """Rectangle of a rendered element in viewport coordinates, as reported by getBoundingClientRect()."""

    x: float  # left coordinate of the rectangle's origin
    y: float  # top coordinate of the rectangle's origin
    width: float
    height: float


@dataclass
class SpanInfo:
    """A leaf text span of a rendered equation together with its on-page rectangle."""

    text: str  # span text, with surrounding whitespace stripped when extracted from the page
    bounding_box: BoundingBox


@dataclass
class RenderedEquation:
    """Result of rendering one equation: KaTeX's MathML markup plus its inner-most text spans."""

    mathml: str  # outerHTML of KaTeX's <math> element, or "" when none was produced
    spans: List[SpanInfo]


def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
    """
    Return the hex SHA1 digest used as the cache key for an equation rendering.

    The equation text, background color, text color and font size are formatted,
    joined with "|" and hashed, so the key depends on the equation and on every
    rendering parameter.
    """
    fields = (equation, bg_color, text_color, font_size)
    key = "|".join(format(field) for field in fields)
    digest = hashlib.sha1(key.encode("utf-8"))
    return digest.hexdigest()


def get_cache_dir():
    """
    Return the directory that stores cached equation renders.

    The directory is ~/.cache/olmocr/bench/equations; it (and any missing parent
    directories) is created on demand.
    """
    directory = pathlib.Path.home().joinpath(".cache", "olmocr", "bench", "equations")
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def clear_cache_dir():
    """
    Delete every file and subdirectory in the equation cache, leaving an empty cache directory.
    """
    target = get_cache_dir()
    # is_dir() is False for a missing path, so this also covers the "does not exist" case.
    if not target.is_dir():
        return
    shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)


def init_browser():
    """
    Start Playwright and launch Chromium for the current thread, unless already done.

    The instances are stored on ``_thread_local`` as ``playwright`` and ``browser``.
    They are only published once *both* have started. If the browser launch fails,
    the Playwright instance that was just started is stopped and the exception
    propagates. The next call then retries from scratch instead of finding a
    half-initialized thread (``playwright`` set but ``browser`` missing).
    """
    if hasattr(_thread_local, "browser"):
        return
    playwright = sync_playwright().start()
    try:
        browser = playwright.chromium.launch()
    except BaseException:
        # Don't leak the started Playwright instance or leave partial thread-local state.
        playwright.stop()
        raise
    _thread_local.playwright = playwright
    _thread_local.browser = browser


def get_browser():
    """
    Return this thread's browser, initializing Playwright and the browser on first use.
    """
    init_browser()
    browser = _thread_local.browser
    return browser


def render_equation(
    equation,
    bg_color="white",
    text_color="black",
    font_size=24,
    use_cache=True,
    debug_dom=False,
):
    """
    Render a LaTeX equation with KaTeX in headless Chromium and extract its layout.

    Collects the inner-most span elements (no child elements, non-whitespace text)
    together with their bounding boxes. Also captures the MathML that KaTeX emits
    alongside its HTML output.

    Results are cached in ``get_cache_dir()`` as ``<hash>.json``. Equations that KaTeX
    rejects are remembered with an empty ``<hash>_error`` marker file. With
    ``use_cache=False`` the cache is not consulted, but fresh results and error
    markers are still written.

    Args:
        equation: LaTeX source, rendered in display mode.
        bg_color: CSS background color of the page.
        text_color: CSS text color of the page.
        font_size: Font size of the equation container, in px.
        use_cache: Return a cached result (or cached failure) when one exists.
        debug_dom: Print the rendered KaTeX HTML and the extracted span data.

    Returns:
        RenderedEquation: A dataclass containing the mathml string and a list of SpanInfo
        dataclasses. Returns None if KaTeX reported an error for this equation, now or
        in an earlier cached run.

    Raises:
        FileNotFoundError: katex.min.css / katex.min.js are not next to this file.
        RuntimeError: The KaTeX script did not define ``katex`` in the page.
        playwright.sync_api.Error: The in-page render call itself failed.
    """
    eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
    cache_dir = get_cache_dir()
    cache_file = cache_dir / f"{eq_hash}.json"
    cache_error_file = cache_dir / f"{eq_hash}_error"

    if use_cache:
        # Lock so we never read a cache file another thread of this process is writing.
        with cache_lock:
            if cache_error_file.exists():
                return None
            cached = _load_cached_equation(cache_file)
        if cached is not None:
            return cached

    # Get local paths for KaTeX files
    script_dir = os.path.dirname(os.path.abspath(__file__))
    katex_css_path = os.path.join(script_dir, "katex.min.css")
    katex_js_path = os.path.join(script_dir, "katex.min.js")

    if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
        raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")

    # Basic HTML structure for rendering; KaTeX renders into #equation-container.
    page_html = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <style>
            body {{
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100vh;
                margin: 0;
                background-color: {bg_color};
                color: {text_color};
            }}
            #equation-container {{
                padding: 0;
                font-size: {font_size}px;
            }}
        </style>
    </head>
    <body>
        <div id="equation-container"></div>
    </body>
    </html>
    """

    browser = get_browser()
    page = browser.new_page(viewport={"width": 800, "height": 400})
    try:
        page.set_content(page_html)
        page.add_style_tag(path=katex_css_path)
        page.add_script_tag(path=katex_js_path)
        page.wait_for_load_state("networkidle")

        if not page.evaluate("typeof katex !== 'undefined'"):
            raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")

        try:
            # The equation is passed as an evaluate() argument, so no manual JS escaping is needed.
            error_message = page.evaluate(_KATEX_RENDER_JS, equation)
        except PlaywrightError:
            print(json.dumps(equation))
            raise

        if error_message:
            print(f"Error rendering equation: '{equation}'")
            print(error_message)
            with cache_lock:
                cache_error_file.touch()
            return None

        page.wait_for_selector(".katex", state="attached")

        if debug_dom:
            katex_dom_html = page.evaluate(_CONTAINER_HTML_JS)
            print("\n===== KaTeX DOM HTML =====")
            print(katex_dom_html)

        spans_info = page.evaluate(_EXTRACT_SPANS_JS)

        if debug_dom:
            print("\n===== Extracted Span Information =====")
            print(spans_info)

        mathml = page.evaluate(_EXTRACT_MATHML_JS)
    finally:
        # Release the page on every path: success, KaTeX error, or exception.
        page.close()

    rendered_eq = RenderedEquation(mathml=mathml, spans=_spans_from_json(spans_info))

    with cache_lock:
        _write_cache_file(cache_file, rendered_eq)
    return rendered_eq


def _spans_from_json(raw_spans):
    """Build SpanInfo objects from dicts shaped like {"text": ..., "boundingBox": {"x", "y", "width", "height"}}."""
    return [
        SpanInfo(
            text=s["text"],
            bounding_box=BoundingBox(
                x=s["boundingBox"]["x"],
                y=s["boundingBox"]["y"],
                width=s["boundingBox"]["width"],
                height=s["boundingBox"]["height"],
            ),
        )
        for s in raw_spans
    ]


def _load_cached_equation(cache_file):
    """Return the RenderedEquation stored in ``cache_file``, or None if it is missing or unreadable."""
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return RenderedEquation(mathml=data["mathml"], spans=_spans_from_json(data["spans"]))
    except FileNotFoundError:
        return None
    except (OSError, ValueError, KeyError, TypeError) as ex:
        # A truncated or corrupt entry is treated as a cache miss. The equation is then
        # re-rendered and the entry overwritten, rather than failing on every future call.
        print(f"Ignoring unreadable cache file {cache_file}: {ex}")
        return None


def _write_cache_file(cache_file, rendered_eq):
    """Write ``rendered_eq`` to ``cache_file`` as JSON, via a temp file + os.replace so readers never see a partial file."""
    cache_data = {
        "mathml": rendered_eq.mathml,
        "spans": [
            {
                "text": span.text,
                "boundingBox": {"x": span.bounding_box.x, "y": span.bounding_box.y, "width": span.bounding_box.width, "height": span.bounding_box.height},
            }
            for span in rendered_eq.spans
        ],
    }
    # Unique per process and thread so concurrent writers never share a temp file.
    tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(cache_data, f)
        os.replace(tmp_file, cache_file)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised below.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise


# Renders the equation (the evaluate() argument) into #equation-container.
# Returns null on success or KaTeX's error message on failure.
_KATEX_RENDER_JS = """
(equation) => {
    try {
        katex.render(equation, document.getElementById("equation-container"), {
            displayMode: true,
            throwOnError: true
        });
        return null;
    } catch (error) {
        console.error("KaTeX error:", error.message);
        return error.message;
    }
}
"""

# Returns the container's HTML after rendering (used for debug output).
_CONTAINER_HTML_JS = """
() => {
    return document.getElementById("equation-container").innerHTML;
}
"""

# Collects inner-most spans with non-whitespace text and their bounding boxes.
# Raw string so the JavaScript regex /\S/ reaches the browser unchanged ("\S" is an
# invalid Python escape in a normal string literal).
_EXTRACT_SPANS_JS = r"""
() => {
    const spans = Array.from(document.querySelectorAll('span'));
    const list = [];
    spans.forEach(span => {
        if (span.children.length === 0 && /\S/.test(span.textContent)) {
            const rect = span.getBoundingClientRect();
            list.push({
                text: span.textContent.trim(),
                boundingBox: {
                    x: rect.x,
                    y: rect.y,
                    width: rect.width,
                    height: rect.height
                }
            });
        }
    });
    return list;
}
"""

# Returns the outerHTML of KaTeX's MathML <math> element, or "" if there is none.
_EXTRACT_MATHML_JS = """
() => {
    const mathElem = document.querySelector('.katex-mathml math');
    return mathElem ? mathElem.outerHTML : "";
}
"""


def compare_rendered_equations(reference: RenderedEquation, hypothesis: RenderedEquation) -> bool:
    """
    Decide whether the ``reference`` equation appears within the ``hypothesis`` rendering.

    Two checks are tried in order:

    1. MathML containment: the whitespace-stripped inner MathML of the *reference*
       (children of ``<semantics>`` other than ``<annotation>``) must be a substring
       of the hypothesis' inner MathML.
    2. Span layout matching: every reference span (zero-width spaces ignored) must be
       assigned to a distinct hypothesis span with identical text. For each direction
       (up/down/left/right) in which a reference span has a nearest neighbour, its
       assigned hypothesis span must also have a neighbour in that direction. That
       neighbour must either be the span already assigned to the reference neighbour,
       or carry the same text. The assignment is found by backtracking.

    Args:
        reference: The expected (gold) rendering.
        hypothesis: The candidate rendering that should contain the reference.

    Returns:
        True if either check succeeds, otherwise False.
    """
    from bs4 import BeautifulSoup

    def extract_inner(mathml: str) -> str:
        # Drop the <annotation> (which holds the TeX source) so only rendered structure is compared.
        try:
            soup = BeautifulSoup(mathml, "xml")
            semantics = soup.find("semantics")
            if semantics:
                inner_parts = [str(child) for child in semantics.contents if getattr(child, "name", None) != "annotation"]
                return "".join(inner_parts)
            else:
                return str(soup)
        except Exception as e:
            print("Error parsing MathML with BeautifulSoup:", e)
            print(mathml)
            return mathml

    def normalize(s: str) -> str:
        return re.sub(r"\s+", "", s)

    reference_inner = normalize(extract_inner(reference.mathml))
    hypothesis_inner = normalize(extract_inner(hypothesis.mathml))
    if reference_inner in hypothesis_inner:
        return True

    # Zero-width spaces are invisible, so they are excluded from layout matching.
    ref_spans = [span for span in reference.spans if span.text != "\u200b"]
    hyp_spans = [span for span in hypothesis.spans if span.text != "\u200b"]

    # For each reference span, the hypothesis spans with identical text; any empty list means no match is possible.
    candidate_map = {}
    for i, ref_span in enumerate(ref_spans):
        candidate_map[i] = [j for j, hyp_span in enumerate(hyp_spans) if hyp_span.text == ref_span.text]
        if not candidate_map[i]:
            return False

    def compute_neighbors(spans, tol=5):
        # Nearest span (by centre distance) in each direction whose centre lies within
        # `tol` px of this span's centre on the perpendicular axis.
        neighbors = {}
        for i, span in enumerate(spans):
            cx = span.bounding_box.x + span.bounding_box.width / 2
            cy = span.bounding_box.y + span.bounding_box.height / 2
            up = down = left = right = None
            up_dist = down_dist = left_dist = right_dist = None
            for j, other in enumerate(spans):
                if i == j:
                    continue
                ocx = other.bounding_box.x + other.bounding_box.width / 2
                ocy = other.bounding_box.y + other.bounding_box.height / 2
                if ocy < cy and abs(ocx - cx) <= tol:
                    dist = cy - ocy
                    if up is None or dist < up_dist:
                        up = j
                        up_dist = dist
                if ocy > cy and abs(ocx - cx) <= tol:
                    dist = ocy - cy
                    if down is None or dist < down_dist:
                        down = j
                        down_dist = dist
                if ocx < cx and abs(ocy - cy) <= tol:
                    dist = cx - ocx
                    if left is None or dist < left_dist:
                        left = j
                        left_dist = dist
                if ocx > cx and abs(ocy - cy) <= tol:
                    dist = ocx - cx
                    if right is None or dist < right_dist:
                        right = j
                        right_dist = dist
            neighbors[i] = {"up": up, "down": down, "left": left, "right": right}
        return neighbors

    ref_neighbors = compute_neighbors(ref_spans)
    hyp_neighbors = compute_neighbors(hyp_spans)

    n = len(ref_spans)
    used = [False] * len(hyp_spans)
    assignment = {}  # reference span index -> hypothesis span index

    def backtrack(i):
        if i == n:
            return True
        for cand in candidate_map[i]:
            if used[cand]:
                continue
            assignment[i] = cand
            used[cand] = True
            valid = True
            for direction in ("up", "down", "left", "right"):
                ref_nb = ref_neighbors[i][direction]
                hyp_nb = hyp_neighbors[cand][direction]
                if ref_nb is None:
                    continue
                if hyp_nb is None:
                    valid = False
                    break
                if ref_nb in assignment:
                    # The neighbour is already placed: the layouts must agree exactly.
                    if assignment[ref_nb] != hyp_nb:
                        valid = False
                        break
                elif hyp_spans[hyp_nb].text != ref_spans[ref_nb].text:
                    # Not placed yet: at least require the same text in that direction.
                    valid = False
                    break
            if valid and backtrack(i + 1):
                return True
            used[cand] = False
            del assignment[i]
        return False

    return backtrack(0)


class TestRenderedEquationComparison(unittest.TestCase):
    """End-to-end checks of compare_rendered_equations on real KaTeX renders.

    Requires Playwright with Chromium and the KaTeX files next to this script.
    """

    def _render(self, equation, **kwargs):
        """Render ``equation`` and fail with a clear message if KaTeX rejected it (render_equation returned None)."""
        rendered = render_equation(equation, **kwargs)
        self.assertIsNotNone(rendered, f"KaTeX failed to render: {equation!r}")
        return rendered

    def test_exact_match(self):
        eq1 = self._render("a+b", use_cache=False)
        eq2 = self._render("a+b", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_whitespace_difference(self):
        eq1 = self._render("a+b", use_cache=False)
        eq2 = self._render("a + b", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_not_found(self):
        eq1 = self._render("c-d", use_cache=False)
        eq2 = self._render("a+b", use_cache=False)
        self.assertFalse(compare_rendered_equations(eq1, eq2))

    def test_align_block_contains_needle(self):
        eq_plain = self._render("a+b", use_cache=False)
        eq_align = self._render("\\begin{align*}a+b\\end{align*}", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq_plain, eq_align))

    def test_align_block_needle_not_in(self):
        eq_align = self._render("\\begin{align*}a+b\\end{align*}", use_cache=False)
        eq_diff = self._render("c-d", use_cache=False)
        self.assertFalse(compare_rendered_equations(eq_diff, eq_align))

    def test_big(self):
        ref_rendered = self._render("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}", use_cache=False, debug_dom=False)
        align_rendered = self._render(
            """\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}""", use_cache=False, debug_dom=False
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end1(self):
        ref_rendered = self._render(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]"
        )
        align_rendered = self._render(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]."
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end2(self):
        ref_rendered = self._render(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\psi(g s)}\\right]"
        )
        align_rendered = self._render(
            "\\lambda_g = \\sum_{s \\in S} \\zeta_n^{\\psi(gs)} = \\sum_{i=1}^{k} \\left[ \\sum_{s, Rs = \\mathcal{I}_i} \\zeta_n^{\\psi(gs)} \\right]"
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_lambda(self):
        ref_rendered = self._render("\\lambda_g = \\lambda_{g'}")
        align_rendered = self._render("\\lambda_{g}=\\lambda_{g^{\\prime}}")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_gemini(self):
        ref_rendered = self._render("u \\in (R/\\operatorname{Ann}_R(x_i))^{\\times}")
        align_rendered = self._render("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))


if __name__ == "__main__":
    unittest.main()
